package subscriptions_test

import (
	"math/big"
	"sync/atomic"
	"testing"
	"time"

	"github.com/ethereum/go-ethereum"
	"github.com/ethereum/go-ethereum/common"
	"github.com/ethereum/go-ethereum/common/hexutil"
	"github.com/onsi/gomega"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/mock"
	"github.com/stretchr/testify/require"

	"github.com/smartcontractkit/chainlink-common/pkg/logger"
	"github.com/smartcontractkit/chainlink-evm/gethwrappers/functions/generated/functions_router"
	"github.com/smartcontractkit/chainlink-evm/pkg/client/clienttest"
	"github.com/smartcontractkit/chainlink/v2/core/internal/testutils"
	"github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/functions/subscriptions"
	smocks "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/functions/subscriptions/mocks"
)

// Owner addresses used by the subscription tests. Each must be a well-formed
// 0x-prefixed 40-hex-digit address: common.HexToAddress does not reject
// malformed input, it silently drops undecodable trailing digits.
const (
	// validUser owns subscriptions in the mocked getSubscriptionsInRange
	// fixture (it appears lowercased in that ABI payload).
	validUser   = "0x9ED925d8206a4f88a2f643b28B3035B315753Cd6"
	// invalidUser owns no subscription, so balance lookups for it must fail.
	invalidUser = "0x6E2dc0F9DB014aE19888F539E59285D2Ea04244C"
	// storedUser owns the subscription served by the mocked ORM store.
	storedUser  = "0x3E2dc0F9DB014aE19888F539E59285D2Ea04233C"
)

// TestSubscriptions_OnePass checks the case where the on-chain subscription
// count (3) fits in a single UpdateRangeSize (3), so one
// getSubscriptionsInRange(1, 3) call covers everything. Once the cache is
// populated, validUser must report the expected max balance and invalidUser,
// which owns nothing, must produce an error.
func TestSubscriptions_OnePass(t *testing.T) {
	t.Parallel()
	// 32-byte ABI word holding the subscription count (3).
	getSubscriptionCount := hexutil.MustDecode("0x0000000000000000000000000000000000000000000000000000000000000003")
	// ABI-encoded response whose array length word is 3: three subscriptions.
	getSubscriptionsInRange := hexutil.MustDecode("0x00000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000003000000000000000000000000000000000000000000000000000000000000006000000000000000000000000000000000000000000000000000000000000001600000000000000000000000000000000000000000000000000000000000000240000000000000000000000000000000000000000000000000de0b6b3a76400000000000000000000000000000109e6e1b12098cc8f3a1e9719a817ec53ab9b35c000000000000000000000000000000000000000000000000000034e23f515cb0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000c000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001000000000000000000000000f5340f0968ee8b7dfd97e3327a6139273cc2c4fa000000000000000000000000000000000000000000000001158e460913d000000000000000000000000000009ed925d8206a4f88a2f643b28b3035b315753cd60000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000c0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001bc14b92364c75e20000000000000000000000009ed925d8206a4f88a2f643b28b3035b315753cd60000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000c0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000010000000000000000000000005439e5881a529f3ccbffc0e82d49f9db3950aefe")

	ctx := testutils.Context(t)
	client := clienttest.NewClient(t)
	client.On("LatestBlockHeight", mock.Anything).Return(big.NewInt(42), nil)
	client.On("CallContract", mock.Anything, ethereum.CallMsg{ // getSubscriptionCount
		To:   &common.Address{},
		Data: hexutil.MustDecode("0x66419970"),
	}, mock.Anything).Return(getSubscriptionCount, nil)
	client.On("CallContract", mock.Anything, ethereum.CallMsg{ // GetSubscriptionsInRange
		To:   &common.Address{},
		Data: hexutil.MustDecode("0xec2454e500000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000003"),
	}, mock.Anything).Return(getSubscriptionsInRange, nil)
	config := subscriptions.OnchainSubscriptionsConfig{
		ContractAddress:    common.Address{},
		BlockConfirmations: 1,
		UpdateFrequencySec: 1,
		UpdateTimeoutSec:   1,
		UpdateRangeSize:    3,
	}
	// The ORM store starts empty, so every balance has to come from the chain.
	// NOTE(review): uint(100) is presumably the default StoreBatchSize — confirm.
	orm := smocks.NewORM(t)
	orm.On("GetSubscriptions", mock.Anything, uint(0), uint(100)).Return([]subscriptions.StoredSubscription{}, nil)
	orm.On("UpsertSubscription", mock.Anything, mock.Anything).Return(nil)

	// Named subs so the imported subscriptions package is not shadowed.
	subs, err := subscriptions.NewOnchainSubscriptions(client, config, orm, logger.Test(t))
	require.NoError(t, err)

	err = subs.Start(ctx)
	require.NoError(t, err)
	t.Cleanup(func() {
		assert.NoError(t, subs.Close())
	})

	// Loop-invariant, so it is computed once instead of on every poll.
	expectedBalance := big.NewInt(0).SetBytes(hexutil.MustDecode("0x01158e460913d00000"))

	// initially we have 3 subs and range is 3, which needs one pass
	gomega.NewGomegaWithT(t).Eventually(func() bool {
		balance, err1 := subs.GetMaxUserBalance(common.HexToAddress(validUser))
		_, err2 := subs.GetMaxUserBalance(common.HexToAddress(invalidUser))
		return err1 == nil && err2 != nil && balance != nil && balance.Cmp(expectedBalance) == 0
	}, testutils.WaitTimeout(t), time.Second).Should(gomega.BeTrue())
}

// TestSubscriptions_MultiPass checks the case where the on-chain subscription
// count (6) exceeds UpdateRangeSize (3). The updater is then expected to page
// through the ranges (1, 3) and (4, 6). The test waits until at least ncycles
// update cycles have run, counting one cycle per getSubscriptionCount call.
func TestSubscriptions_MultiPass(t *testing.T) {
	t.Parallel()
	// Minimum number of update cycles to observe before the test passes.
	const ncycles int32 = 5
	var currentCycle atomic.Int32
	// 32-byte ABI word holding the subscription count (6).
	getSubscriptionCount := hexutil.MustDecode("0x0000000000000000000000000000000000000000000000000000000000000006")
	// The same three-subscription fixture is served for both ranges; this test
	// only checks pagination and cycling, not the decoded balances.
	getSubscriptionsInRange := hexutil.MustDecode("0x00000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000003000000000000000000000000000000000000000000000000000000000000006000000000000000000000000000000000000000000000000000000000000001600000000000000000000000000000000000000000000000000000000000000240000000000000000000000000000000000000000000000000de0b6b3a76400000000000000000000000000000109e6e1b12098cc8f3a1e9719a817ec53ab9b35c000000000000000000000000000000000000000000000000000034e23f515cb0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000c000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001000000000000000000000000f5340f0968ee8b7dfd97e3327a6139273cc2c4fa000000000000000000000000000000000000000000000001158e460913d000000000000000000000000000009ed925d8206a4f88a2f643b28b3035b315753cd60000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000c0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001bc14b92364c75e20000000000000000000000009ed925d8206a4f88a2f643b28b3035b315753cd60000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000c0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000010000000000000000000000005439e5881a529f3ccbffc0e82d49f9db3950aefe")

	ctx := testutils.Context(t)
	client := clienttest.NewClient(t)
	client.On("LatestBlockHeight", mock.Anything).Return(big.NewInt(42), nil)
	client.On("CallContract", mock.Anything, ethereum.CallMsg{ // getSubscriptionCount
		To:   &common.Address{},
		Data: hexutil.MustDecode("0x66419970"),
	}, mock.Anything).Run(func(mock.Arguments) {
		currentCycle.Add(1)
	}).Return(getSubscriptionCount, nil)
	client.On("CallContract", mock.Anything, ethereum.CallMsg{ // GetSubscriptionsInRange(1,3)
		To:   &common.Address{},
		Data: hexutil.MustDecode("0xec2454e500000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000003"),
	}, mock.Anything).Return(getSubscriptionsInRange, nil)
	client.On("CallContract", mock.Anything, ethereum.CallMsg{ // GetSubscriptionsInRange(4,6)
		To:   &common.Address{},
		Data: hexutil.MustDecode("0xec2454e500000000000000000000000000000000000000000000000000000000000000040000000000000000000000000000000000000000000000000000000000000006"),
	}, mock.Anything).Return(getSubscriptionsInRange, nil)
	config := subscriptions.OnchainSubscriptionsConfig{
		ContractAddress:    common.Address{},
		BlockConfirmations: 1,
		UpdateFrequencySec: 1,
		UpdateTimeoutSec:   1,
		UpdateRangeSize:    3,
	}
	orm := smocks.NewORM(t)
	orm.On("GetSubscriptions", mock.Anything, uint(0), uint(100)).Return([]subscriptions.StoredSubscription{}, nil)
	orm.On("UpsertSubscription", mock.Anything, mock.Anything).Return(nil)

	// Named subs so the imported subscriptions package is not shadowed.
	subs, err := subscriptions.NewOnchainSubscriptions(client, config, orm, logger.Test(t))
	require.NoError(t, err)

	err = subs.Start(ctx)
	require.NoError(t, err)
	t.Cleanup(func() {
		assert.NoError(t, subs.Close())
	})

	// The updater keeps cycling after reaching ncycles, so the counter can
	// move past it between one-second polls. An exact == check could miss the
	// target and time out; >= cannot.
	gomega.NewGomegaWithT(t).Eventually(func() bool {
		return currentCycle.Load() >= ncycles
	}, testutils.WaitTimeout(t), time.Second).Should(gomega.BeTrue())
}

// TestSubscriptions_Stored checks that a subscription held in the ORM store is
// served by GetMaxUserBalance for its owner (storedUser). StoreBatchSize is 1.
// The ORM mock returns one subscription for the (0, 1) query and an empty page
// for the (1, 1) query.
func TestSubscriptions_Stored(t *testing.T) {
	t.Parallel()
	// 32-byte ABI word holding the subscription count (3).
	getSubscriptionCount := hexutil.MustDecode("0x0000000000000000000000000000000000000000000000000000000000000003")
	// ABI-encoded response whose array length word is 3: three subscriptions.
	getSubscriptionsInRange := hexutil.MustDecode("0x00000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000003000000000000000000000000000000000000000000000000000000000000006000000000000000000000000000000000000000000000000000000000000001600000000000000000000000000000000000000000000000000000000000000240000000000000000000000000000000000000000000000000de0b6b3a76400000000000000000000000000000109e6e1b12098cc8f3a1e9719a817ec53ab9b35c000000000000000000000000000000000000000000000000000034e23f515cb0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000c000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001000000000000000000000000f5340f0968ee8b7dfd97e3327a6139273cc2c4fa000000000000000000000000000000000000000000000001158e460913d000000000000000000000000000009ed925d8206a4f88a2f643b28b3035b315753cd60000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000c0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001bc14b92364c75e20000000000000000000000009ed925d8206a4f88a2f643b28b3035b315753cd60000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000c0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000010000000000000000000000005439e5881a529f3ccbffc0e82d49f9db3950aefe")

	ctx := testutils.Context(t)
	client := clienttest.NewClient(t)
	client.On("LatestBlockHeight", mock.Anything).Return(big.NewInt(42), nil)
	client.On("CallContract", mock.Anything, ethereum.CallMsg{ // getSubscriptionCount
		To:   &common.Address{},
		Data: hexutil.MustDecode("0x66419970"),
	}, mock.Anything).Return(getSubscriptionCount, nil)
	client.On("CallContract", mock.Anything, ethereum.CallMsg{ // GetSubscriptionsInRange
		To:   &common.Address{},
		Data: hexutil.MustDecode("0xec2454e500000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000003"),
	}, mock.Anything).Return(getSubscriptionsInRange, nil)
	config := subscriptions.OnchainSubscriptionsConfig{
		ContractAddress:    common.Address{},
		BlockConfirmations: 1,
		UpdateFrequencySec: 1,
		UpdateTimeoutSec:   1,
		UpdateRangeSize:    3,
		StoreBatchSize:     1,
	}

	expectedBalance := big.NewInt(5)
	// NOTE(review): the two uint args are presumably (offset, limit) — confirm
	// against the ORM interface.
	orm := smocks.NewORM(t)
	orm.On("GetSubscriptions", mock.Anything, uint(0), uint(1)).Return([]subscriptions.StoredSubscription{
		{
			SubscriptionID: 1,
			IFunctionsSubscriptionsSubscription: functions_router.IFunctionsSubscriptionsSubscription{
				Balance:        expectedBalance,
				Owner:          common.HexToAddress(storedUser),
				BlockedBalance: big.NewInt(10),
			},
		},
	}, nil)
	orm.On("GetSubscriptions", mock.Anything, uint(1), uint(1)).Return([]subscriptions.StoredSubscription{}, nil)
	orm.On("UpsertSubscription", mock.Anything, mock.Anything).Return(nil)

	// Named subs so the imported subscriptions package is not shadowed.
	subs, err := subscriptions.NewOnchainSubscriptions(client, config, orm, logger.Test(t))
	require.NoError(t, err)

	err = subs.Start(ctx)
	require.NoError(t, err)
	t.Cleanup(func() {
		assert.NoError(t, subs.Close())
	})

	gomega.NewGomegaWithT(t).Eventually(func() bool {
		actualBalance, lookupErr := subs.GetMaxUserBalance(common.HexToAddress(storedUser))
		// The poll must have no side effects. assert.Equal(t, ...) would record
		// a test failure on any intermediate mismatch, even if a later poll
		// succeeds, so compare the values directly instead.
		return lookupErr == nil && actualBalance != nil && actualBalance.Cmp(expectedBalance) == 0
	}, testutils.WaitTimeout(t), time.Second).Should(gomega.BeTrue())
}
